#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
Function:
The replay function entry
"""

import os
import re
import stat
import sys

import const_var


def _macro_inner(pattern, line: str, line_no: int) -> str:
    """Return the raw text between the parentheses of the macro call on *line*.

    Raises:
        ValueError: if *line* has no parenthesized argument list.
    """
    found = pattern.findall(line)
    if not found:
        raise ValueError(
            "line {}: expected '(...)' arguments in: {}".format(line_no, line)
        )
    return found[0]


def _macro_fields(pattern, line: str, line_no: int, min_fields: int) -> list:
    """Return the stripped, comma-separated arguments of the macro call on *line*.

    Raises:
        ValueError: if the argument list is missing or has fewer than
            *min_fields* entries.
    """
    fields = [item.strip() for item in _macro_inner(pattern, line, line_no).split(",")]
    if len(fields) < min_fields:
        raise ValueError(
            "line {}: expected at least {} argument(s) in: {}".format(
                line_no, min_fields, line
            )
        )
    return fields


def _struct_epilogue(struct_name: str) -> str:
    """Return the text closing the struct plus its two Init<struct> functions.

    NOTE(review): the generated copy loops run for sizeof / 4 (under
    __NPU_TILING__) and sizeof / 8 (otherwise) iterations, so trailing bytes are
    not copied when the packed struct size is not a multiple of 4 / 8 --
    confirm tiling structs are always padded accordingly.
    """
    return "".join(
        [
            "};\n",
            "#pragma pack()\n\n",
            "#ifdef __NPU_TILING__\n",
            "inline [aicore] void Init{stru}(const __gm__ uint8_t* tiling, {stru}* const_data)\n".format(
                stru=struct_name
            ),
            "{\n",
            "    const __gm__ uint32_t *src = (const __gm__ uint32_t *)tiling;\n",
            "    uint32_t *dst = (uint32_t *)const_data;\n",
            "    for (auto i = 0; i < sizeof({}) / 4; i++) *(dst + i) = *(src + i);\n".format(
                struct_name
            ),
            "}\n",
            "#else\n",
            "inline void Init{stru}(uint8_t* tiling, {stru}* const_data)\n".format(
                stru=struct_name
            ),
            "{\n",
            "    uint64_t *src = (uint64_t *)tiling;\n",
            "    uint64_t *dst = (uint64_t *)const_data;\n",
            "    for (auto i = 0; i < sizeof({}) / 8; i++) *(dst + i) = *(src + i);\n".format(
                struct_name
            ),
            "}\n",
            "#endif\n\n",
        ]
    )


def _get_tiling_data_macro(struct_name: str) -> str:
    """Return the GET_TILING_DATA macro redefinition bound to *struct_name*."""
    return (
        "\n#undef GET_TILING_DATA\n"
        "#define GET_TILING_DATA(tiling_data, tiling_arg) \\\n"
        "{stru} tiling_data; \\\n"
        "Init{stru}(tiling_arg, &tiling_data)\n\n"
    ).format(stru=struct_name)


def _render_tiling_header(guard_name: str, lines) -> str:
    """Translate tiling definition macro lines into the generated header text.

    Args:
        guard_name: identifier used in the ``__<guard_name>_H__`` include guard.
        lines: iterable of source lines from the user tiling header.

    Raises:
        ValueError: on a recognised macro line that is malformed, or on
            END_TILING_DATA_DEF without a preceding BEGIN_TILING_DATA_DEF.
    """
    # Greedy on purpose (unchanged): captures from the first "(" to the last
    # ")" on the line, so parentheses in a trailing comment are captured too.
    pattern = re.compile(r"[(](.*)[)]", re.S)
    parts = [
        "#ifndef __{}_H__\n".format(guard_name),
        "#define __{}_H__\n\n".format(guard_name),
        "#include <cstdint>\n",
        "#include <cstring>\n\n",
        '#include "kernel_tiling/kernel_tiling.h"\n\n',
    ]
    accessor_macro = ""
    struct_name = None
    for line_no, raw_line in enumerate(lines, start=1):
        line = raw_line.strip()
        if line.startswith("BEGIN_TILING_DATA_DEF"):
            # Stripped so that "( Foo )" still yields the valid name "InitFoo".
            struct_name = _macro_inner(pattern, line, line_no).strip()
            parts.append("#pragma pack(1)\n")
            parts.append("struct " + struct_name + " {\n")
        # The _ARR and _STRUCT variants share the TILING_DATA_FIELD_DEF prefix,
        # so they must be tested before the plain field macro.
        elif line.startswith("TILING_DATA_FIELD_DEF_ARR"):
            elem_type, length, field_name = _macro_fields(pattern, line, line_no, 3)[:3]
            parts.append("    {} {}[{}] = {{}};\n".format(elem_type, field_name, length))
        elif line.startswith("TILING_DATA_FIELD_DEF_STRUCT"):
            field_type, field_name = _macro_fields(pattern, line, line_no, 2)[:2]
            parts.append("    {} {};\n".format(field_type, field_name))
        elif line.startswith("TILING_DATA_FIELD_DEF"):
            field_type, field_name = _macro_fields(pattern, line, line_no, 2)[:2]
            parts.append("    {} {} = 0;\n".format(field_type, field_name))
        elif line.startswith("END_TILING_DATA_DEF"):
            if struct_name is None:
                raise ValueError(
                    "line {}: END_TILING_DATA_DEF without BEGIN_TILING_DATA_DEF".format(
                        line_no
                    )
                )
            parts.append(_struct_epilogue(struct_name))
            # Only the most recently closed struct is bound to GET_TILING_DATA.
            accessor_macro = _get_tiling_data_macro(struct_name)
    parts.append(accessor_macro)
    parts.append("#endif")
    return "".join(parts)


def gen_tiling(tiling_header_file: str, tiling_file_out: str) -> None:
    """Generate a plain-struct tiling header from a tiling definition header.

    Lines of *tiling_header_file* starting with BEGIN_TILING_DATA_DEF,
    TILING_DATA_FIELD_DEF, TILING_DATA_FIELD_DEF_ARR,
    TILING_DATA_FIELD_DEF_STRUCT or END_TILING_DATA_DEF are translated into a
    packed C++ struct, Init<struct> copy functions and a GET_TILING_DATA macro;
    all other lines are ignored.

    Args:
        tiling_header_file: path of the user tiling definition header. If it
            does not exist, a warning is printed and nothing is written.
        tiling_file_out: path of the header to write; opened with
            const_var.WFLAGS and const_var.WMODES.

    Raises:
        ValueError: if a recognised macro line is malformed. The input is parsed
            completely before the output file is opened, so no output file is
            produced in that case.
    """
    if not os.path.exists(tiling_header_file):
        print("warning: no userdef tiling header file: ", tiling_header_file)
        return
    print("generate tiling def header file: ", tiling_file_out)
    guard_name = os.path.splitext(os.path.basename(tiling_header_file))[0].upper()
    with open(tiling_header_file, "r") as fd:
        tiling_source = _render_tiling_header(guard_name, fd)
    with os.fdopen(
        os.open(tiling_file_out, const_var.WFLAGS, const_var.WMODES), "w"
    ) as ofd:
        ofd.write(tiling_source)


if __name__ == "__main__":
    # Usage: <script> <tiling_header_file> <tiling_file_out>
    # Any arguments after the first two are ignored.
    if not len(sys.argv) > 2:
        raise RuntimeError("arguments must greater than 2")
    header_path, output_path = sys.argv[1:3]
    gen_tiling(header_path, output_path)
